package netty;

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.string.StringDecoder;
import io.netty.handler.codec.string.StringEncoder;
import listener.ITaskFinishListener;
import netty.handler.RpcResponseHandler;
import rpc.CompletableRpcTask;
import rpc.TaskMessageWrap;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

/**
 * RPC client transport: talks to a remote RPC server over a single Netty channel and wraps that
 * channel behind blocking {@code get} calls.
 *
 * <p>Pending requests are kept in {@code rpcTaskMap}, keyed by task id. The same map is handed to
 * {@link RpcResponseHandler}, which presumably completes the matching task when a response arrives
 * (TODO confirm against RpcResponseHandler).
 *
 * <p>Call {@link #close()} when the client is no longer needed; otherwise the event loop group
 * created in the constructor is never shut down and its threads are leaked.
 *
 * @author : zGame
 * @version V1.0
 * @Project: simple-rpc
 * @Package netty
 * @date 2022-02-28 10:11
 */
public class NettyClient implements ITaskFinishListener, AutoCloseable {

    /** In-flight RPC tasks by task id; shared with the pipeline's {@link RpcResponseHandler}. */
    private final Map<String, CompletableRpcTask> rpcTaskMap = new ConcurrentHashMap<>();

    /** Event loop group owning the I/O thread(s) of {@link #channel}; released by {@link #close()}. */
    private final EventLoopGroup workerGroup;

    /** Connected channel; never null once the constructor returns. */
    private final Channel channel;

    /**
     * Sends a request and blocks, without a timeout, until its RPC result is available.
     *
     * <p>The calling thread first waits for the request to be written, then blocks on
     * {@code rpcTask.get()} until the remote result arrives.
     *
     * @param taskMessageWrap request to send; its {@code toString()} form goes on the wire
     * @param rpcTask         task that receives the result; registered under its task id
     * @param <T>             expected result type (unchecked)
     * @return the RPC result, cast to {@code T}
     * @throws ExecutionException   if the request could not be written, or if the task fails
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    @SuppressWarnings("unchecked") // T is chosen by the caller; a mismatch surfaces as ClassCastException there
    public <T> T get(TaskMessageWrap taskMessageWrap, CompletableRpcTask rpcTask)
            throws ExecutionException, InterruptedException {
        ChannelFuture writeFuture = send(taskMessageWrap, rpcTask);
        // Long.MAX_VALUE ns is effectively unbounded, matching this overload's no-timeout contract.
        awaitWrite(writeFuture, rpcTask.getTaskId(), Long.MAX_VALUE, TimeUnit.NANOSECONDS);
        return (T) rpcTask.get();
    }

    /**
     * Sends a request and blocks until its RPC result is available or the timeout elapses.
     *
     * <p>The timeout is applied first to writing the request and then to waiting for the
     * response, so in the worst case the total wait can approach twice {@code time}.
     *
     * @param taskMessageWrap request to send; its {@code toString()} form goes on the wire
     * @param rpcTask         task that receives the result; registered under its task id
     * @param time            maximum time to wait
     * @param timeUnit        unit of {@code time}
     * @param <T>             expected result type (unchecked)
     * @return the RPC result, cast to {@code T}
     * @throws ExecutionException   if the request could not be written, or if the task fails
     * @throws InterruptedException if the calling thread is interrupted while waiting
     * @throws TimeoutException     if the write or the response does not complete in time
     */
    @SuppressWarnings("unchecked") // T is chosen by the caller; a mismatch surfaces as ClassCastException there
    public <T> T get(TaskMessageWrap taskMessageWrap, CompletableRpcTask rpcTask, long time, TimeUnit timeUnit)
            throws ExecutionException, InterruptedException, TimeoutException {
        ChannelFuture writeFuture = send(taskMessageWrap, rpcTask);
        if (!awaitWrite(writeFuture, rpcTask.getTaskId(), time, timeUnit)) {
            throw new TimeoutException("timed out sending rpc task " + rpcTask.getTaskId());
        }
        return (T) rpcTask.get(time, timeUnit);
    }

    /**
     * Registers {@code rpcTask} and writes the request to the channel.
     *
     * @return the (possibly still pending) write future
     */
    private ChannelFuture send(TaskMessageWrap taskMessageWrap, CompletableRpcTask rpcTask) {
        // Register before writing so that a fast response can always find its task.
        rpcTaskMap.put(rpcTask.getTaskId(), rpcTask);
        return channel.writeAndFlush(taskMessageWrap.toString());
    }

    /**
     * Waits for the write of task {@code taskId}'s request to finish.
     *
     * <p>Whenever this does not return {@code true}, the task is unregistered: the request never
     * (fully) reached the server, so keeping it in the map would only leak it.
     *
     * @return {@code true} if the write succeeded; {@code false} if it did not finish within the time
     * @throws ExecutionException   if the write failed (the cause is the write failure)
     * @throws InterruptedException if interrupted while waiting
     */
    private boolean awaitWrite(ChannelFuture writeFuture, String taskId, long time, TimeUnit timeUnit)
            throws ExecutionException, InterruptedException {
        boolean written = false;
        try {
            if (!writeFuture.await(time, timeUnit)) {
                return false;
            }
            if (!writeFuture.isSuccess()) {
                throw new ExecutionException("failed to send rpc task " + taskId, writeFuture.cause());
            }
            written = true;
            return true;
        } finally {
            if (!written) {
                rpcTaskMap.remove(taskId);
            }
        }
    }

    /**
     * Creates the client and synchronously connects it to {@code host:port}.
     *
     * <p>The pipeline decodes and encodes strings and dispatches responses to
     * {@link RpcResponseHandler}. If the connection cannot be established, the event loop group is
     * shut down before the exception propagates.
     *
     * @param name client name (currently unused)
     * @param host server host
     * @param port server port
     * @throws IllegalStateException if interrupted while connecting (the interrupt status is restored);
     *                               a failed connect is rethrown as reported by Netty's {@code sync()}
     */
    public NettyClient(String name, String host, int port) {
        this.workerGroup = new NioEventLoopGroup();
        Bootstrap bootstrap = new Bootstrap()
                .group(workerGroup)
                .channel(NioSocketChannel.class)
                .option(ChannelOption.SO_KEEPALIVE, true)
                .option(ChannelOption.SO_REUSEADDR, true)
                .handler(new ChannelInitializer<SocketChannel>() {
                    @Override
                    public void initChannel(SocketChannel ch) throws Exception {
                        ch.pipeline().addLast(new StringDecoder());
                        ch.pipeline().addLast(new StringEncoder());
                        ch.pipeline().addLast(new RpcResponseHandler(rpcTaskMap));
                    }
                });
        boolean connected = false;
        try {
            this.channel = bootstrap.connect(host, port).sync().channel();
            connected = true;
        } catch (InterruptedException e) {
            // Previously swallowed, leaving channel null and the interrupt lost.
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while connecting to " + host + ":" + port, e);
        } finally {
            if (!connected) {
                // Do not leak the event loop threads of a client that never came up.
                workerGroup.shutdownGracefully();
            }
        }
    }

    /**
     * Closes the channel and shuts down the event loop group. Pending tasks are not completed here.
     */
    @Override
    public void close() {
        channel.close();
        workerGroup.shutdownGracefully();
    }

    /**
     * Listener callback for a finished RPC task; removes the task from the pending map.
     *
     * @param type   0 = interrupted/cancelled unexpectedly, 1 = timed out, 2 = completed normally
     * @param taskId id of the finished task
     */
    @Override
    public void taskFinish(byte type, String taskId) {
        System.out.println("taskFinish type:" + type + " taskId:" + taskId);
        CompletableRpcTask v = rpcTaskMap.remove(taskId);
        if (v == null) {
            // Nothing was registered under taskId any more, e.g. it was already removed earlier.
            System.out.println("rpc is timeOut by self:" + type + " taskId:" + taskId);
        }
    }
}
